{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    },
    "tags": []
   },
   "source": [
    "## Copyright 2022 Google LLC. Double-click for license information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Copyright 2022 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#      http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    },
    "tags": []
   },
   "source": [
    "## Prompt-to-Prompt with Stable Diffusion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-04-01 13:23:53,975 - modelscope - INFO - PyTorch version 2.1.2 Found.\n",
      "2024-04-01 13:23:53,978 - modelscope - INFO - Loading ast index from /home/smotamed/.cache/modelscope/ast_indexer\n",
      "2024-04-01 13:23:53,998 - modelscope - INFO - Loading done! Current index file version is 1.4.2, with md5 be51e39fac31bdb24de5e4fbf8d346a2 and a total number of 842 components indexed\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initializing the conversion map\n",
      "No module named 'tensorflow'\n"
     ]
    }
   ],
   "source": [
    "from typing import Optional, Union, Tuple, List, Callable, Dict\n",
    "import torch\n",
    "from diffusers import StableDiffusionPipeline\n",
    "from diffusers.models.attention_processor import AttnProcessor, Attention\n",
    "import torch.nn.functional as nnf\n",
    "import numpy as np\n",
    "import sys\n",
    "import abc\n",
    "import fastcore.all as fc\n",
    "import math\n",
    "from skimage.draw import disk\n",
    "import torch.nn.functional as F\n",
    "from functools import partial\n",
    "from utils.guidance_functions import *\n",
    "import ptp_utils\n",
    "import numpy\n",
    "from compel import Compel\n",
    "import diffusers\n",
    "import matplotlib.pyplot as plt\n",
    "from diffusers import DPMSolverMultistepScheduler, TextToVideoSDPipeline, UNet3DConditionModel\n",
    "from diffusers import DiffusionPipeline\n",
    "from diffusers.utils import export_to_video\n",
    "from einops import rearrange\n",
    "from huggingface_hub import snapshot_download\n",
    "import sys\n",
    "\n",
    "from modelscope.pipelines import pipeline\n",
    "from modelscope.outputs import OutputKeys\n",
    "import pathlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "For loading the Stable Diffusion using Diffusers, follow the instuctions https://huggingface.co/blog/stable_diffusion and update ```MY_TOKEN``` with your token.\n",
    "Set ```LOW_RESOURCE``` to ```True``` for running on 12GB GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['/home/smotamed/Video-Editing-X-Attention', '/home/smotamed/anaconda3/envs/text2video-finetune/lib/python310.zip', '/home/smotamed/anaconda3/envs/text2video-finetune/lib/python3.10', '/home/smotamed/anaconda3/envs/text2video-finetune/lib/python3.10/lib-dynload', '', '/home/smotamed/.local/lib/python3.10/site-packages', '/home/smotamed/anaconda3/envs/text2video-finetune/lib/python3.10/site-packages', '/tmp/tmp873k3ec7']\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'prompt = \"darth vader surfing in the ocean\"\\nvideo_frames = pipe(prompt, num_inference_steps=40, height=320, width=504, num_frames=20).frames\\nexport_to_video(video_frames, \"testit_shark.mp4\")'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "MY_TOKEN = '<me>'\n",
    "print(sys.path)\n",
    "\n",
    "numpy.set_printoptions(threshold=sys.maxsize)\n",
    "LOW_RESOURCE = False \n",
    "NUM_DIFFUSION_STEPS = 50\n",
    "GUIDANCE_SCALE = 7.5\n",
    "MAX_NUM_WORDS = 77\n",
    "device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "#ldm_stable = StableDiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", use_auth_token=MY_TOKEN).to(device)\n",
    "#tokenizer = ldm_stable.tokenizer\n",
    "from PIL import Image\n",
    "#pipe = DiffusionPipeline.from_pretrained(\"damo-vilab/text-to-video-ms-1.7b\", torch_dtype=torch.float16, variant=\"fp16\")\n",
    "#pipe = DiffusionPipeline.from_pretrained(\"cerspense/zeroscope_v2_576w\", torch_dtype=torch.float16)\n",
    "#pipe.enable_model_cpu_offload()\n",
    "pipe = DiffusionPipeline.from_pretrained(\"damo-vilab/text-to-video-ms-1.7b\", torch_dtype=torch.float16, variant=\"fp16\")\n",
    "pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n",
    "pipe.enable_model_cpu_offload()\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    },
    "tags": []
   },
   "source": [
    "## Prompt-to-Prompt Attnetion Controllers\n",
    "Our main logic is implemented in the `forward` call in an `AttentionControl` object.\n",
    "The forward is called in each attention layer of the diffusion model and it can modify the input attnetion weights `attn`.\n",
    "\n",
    "`is_cross`, `place_in_unet in (\"down\", \"mid\", \"up\")`, `AttentionControl.cur_step` help us track the exact attention layer and timestamp during the diffusion iference.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def normalize(x): return (x - x.min()) / (x.max() - x.min())\n",
    "def threshold_attention(attn, s=10):\n",
    "    norm_attn = s * (normalize(attn) - 0.5)\n",
    "    return normalize(norm_attn.sigmoid())\n",
    "\n",
    "def get_shape(attn, s=20): \n",
    "    return threshold_attention(attn, s)\n",
    "\n",
    "def get_size(attn): \n",
    "    return 1/attn.shape[-2] * threshold_attention(attn).sum((1,2)).mean()\n",
    "\n",
    "def enlarge(x, scale_factor=1.0):\n",
    "    x = x.view(1, -1, 1)\n",
    "    assert scale_factor >= 1\n",
    "\n",
    "    h = w = int(math.sqrt(x.shape[-2]))\n",
    "    x = rearrange(x, 'n (h w) d -> n d h w', h=h)\n",
    "    x = F.interpolate(x, scale_factor=scale_factor)\n",
    "    new_h = new_w = x.shape[-1]\n",
    "    x_l, x_r = (new_w//2) - w//2, (new_w//2) + w//2\n",
    "    x_t, x_b = (new_h//2) - h//2, (new_h//2) + h//2\n",
    "    x = x[:,:,x_t:x_b,x_l:x_r]\n",
    "    return rearrange(x, 'n d h w -> n (h w) d', h=h) * scale_factor\n",
    "\n",
    "def create_circular_mask(h, w, center=None, radius=None):\n",
    "\n",
    "    if center is None: # use the middle of the image\n",
    "        center = (int(w/2), int(h/2))\n",
    "    if radius is None: # use the smallest distance between the center and image walls\n",
    "        radius = min(center[0], center[1], w-center[0], h-center[1])\n",
    "\n",
    "    Y, X = np.ogrid[:h, :w]\n",
    "    dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2)\n",
    "\n",
    "    mask = dist_from_center <= radius\n",
    "    return mask\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def run_and_display(prompt, obj_to_edit, objects, guidance_func,  latent=None, run_baseline=False):\n",
    "    videos, orig_video, x_t = ptp_utils.text2video(pipe, prompt, obj_to_edit, objects, guidance_func) #num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=GUIDANCE_SCALE, low_resource=LOW_RESOURCE)\n",
    "    \n",
    "    return videos, orig_video, x_t"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    },
    "tags": []
   },
   "source": [
    "## Replacement edit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/smotamed/anaconda3/envs/text2video-finetune/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss : tensor(2.7852, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6836, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6562, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6387, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6152, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6055, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6035, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.5938, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6074, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6152, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6016, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6055, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.5977, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.5957, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.5957, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.5957, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.5977, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6016, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.5996, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6016, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.5996, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6113, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6211, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6270, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6211, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6172, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6172, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6191, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6152, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6133, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6133, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6133, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n",
      "loss : tensor(2.6152, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Decoding to pixels...: 100%|██████████| 16/16 [00:01<00:00,  9.92frame/s]\n",
      "Decoding to pixels...: 100%|██████████| 16/16 [00:00<00:00, 50.58frame/s]\n"
     ]
    }
   ],
   "source": [
    "prompts = [\"darth vader surfs in the ocean\"]\n",
    "tokens = [2]\n",
    "move = partial(roll_shape, direction='right', factor=.1)\n",
    "guidance = partial(edit_by_E, shape_weight=0, appearance_weight = 0, position_weight=4, tau=move)\n",
    "obj_to_edit = 'darth'\n",
    "objects = ['darth', \"ocean\"]\n",
    "videos, orig_video, x_t  = run_and_display(prompts, obj_to_edit, objects, guidance_func=guidance,  latent=None, run_baseline=False)\n",
    "\n",
    "for video in [videos[0]]:\n",
    "    video = rearrange(video.cpu(), \"c f h w -> f h w c\").clamp(-1, 1).add(1).mul(127.5)\n",
    "    video = video.byte().cpu().numpy()\n",
    "    export_to_video(video, \"edited_video.mp4\")\n",
    "\n",
    "for video in [orig_video[0]]:\n",
    "    video = rearrange(video.cpu(), \"c f h w -> f h w c\").clamp(-1, 1).add(1).mul(127.5)\n",
    "    video = video.byte().cpu().numpy()\n",
    "    export_to_video(video, \"original_video.mp4\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "python3",
   "name": "pytorch-gpu.1-11.m94",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/pytorch-gpu.1-11:m94"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
